{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "“TextRNN-Torch.ipynb”的副本",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "CWaBEkKSbLAg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "'''\n",
        "  code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor\n",
        "'''\n",
        "import torch\n",
        "import numpy as np\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.utils.data as Data\n",
        "\n",
        "dtype = torch.FloatTensor"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4n7bPYtCcIWU",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "sentences = [\"i like dog\", \"i love coffee\", \"i hate milk\"]\n",
        "\n",
        "word_list = \" \".join(sentences).split()\n",
        "vocab = list(set(word_list))\n",
        "word2idx = {w: i for i, w in enumerate(vocab)}\n",
        "idx2word = {i: w for i, w in enumerate(vocab)}\n",
        "n_class = len(vocab)\n",
        "\n",
        "# TextRNN Parameter\n",
        "batch_size = 2\n",
        "n_step = 2 # number of cells(= number of Step)\n",
        "n_hidden = 5 # number of hidden units in one cell\n",
        "\n",
        "def make_data(sentences):\n",
        "    input_batch = []\n",
        "    target_batch = []\n",
        "\n",
        "    for sen in sentences:\n",
        "        word = sen.split()\n",
        "        input = [word2idx[n] for n in word[:-1]]\n",
        "        target = word2idx[word[-1]]\n",
        "\n",
        "        input_batch.append(np.eye(n_class)[input])\n",
        "        target_batch.append(target)\n",
        "\n",
        "    return input_batch, target_batch\n",
        "\n",
        "input_batch, target_batch = make_data(sentences)\n",
        "input_batch, target_batch = torch.Tensor(input_batch), torch.LongTensor(target_batch)\n",
        "dataset = Data.TensorDataset(input_batch, target_batch)\n",
        "loader = Data.DataLoader(dataset, batch_size, True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Cz38D1eC8jER",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "dc97488e-f4ae-40bd-9795-95676d6de3aa"
      },
      "source": [
        "# print(input_batch)\n",
        "# print(input_batch.shape)\n",
        "# print(np.eye(7)[[1, 2, 5]])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "torch.Size([3, 2, 7])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SnPg909Gc-OC",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class TextRNN(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(TextRNN, self).__init__()\n",
        "        self.rnn = nn.RNN(input_size=n_class, hidden_size=n_hidden)\n",
        "        # input_size 指的是每个单词用多少维的向量去编码\n",
        "        # hidden_size 指的是输出维度是多少\n",
        "        # fully connected layer\n",
        "        self.fc = nn.Linear(n_hidden, n_class)\n",
        "\n",
        "    def forward(self, hidden, X):\n",
        "        # X: [batch_size, n_step, n_class]\n",
        "        X = X.transpose(0, 1) # X : [n_step, batch_size, n_class]\n",
        "        out, hidden = self.rnn(X, hidden)\n",
        "        # out : [n_step, batch_size, num_directions(=1) * n_hidden]\n",
        "        # hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n",
        "        out = out[-1] # [batch_size, num_directions(=1) * n_hidden] ⭐\n",
        "        model = self.fc(out)\n",
        "        return model\n",
        "\n",
        "model = TextRNN()\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "optimizer = optim.Adam(model.parameters(), lr=0.001)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "32Fac6LDZJmE",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 185
        },
        "outputId": "43d5f1d9-ef96-47fd-9616-a0fce675c643"
      },
      "source": [
        "# Training\n",
        "for epoch in range(5000):\n",
        "    for x, y in loader:\n",
        "      # hidden : [num_layers * num_directions, batch, hidden_size]\n",
        "      hidden = torch.zeros(1, x.shape[0], n_hidden) # h0\n",
        "      # x : [batch_size, n_step, n_class]\n",
        "      pred = model(hidden, x)\n",
        "\n",
        "      # pred : [batch_size, n_class], y : [batch_size] (LongTensor, not one-hot)\n",
        "      loss = criterion(pred, y)\n",
        "      if (epoch + 1) % 1000 == 0:\n",
        "          print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n",
        "\n",
        "      optimizer.zero_grad()\n",
        "      loss.backward()\n",
        "      optimizer.step()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 1000 cost = 0.045840\n",
            "Epoch: 1000 cost = 0.059695\n",
            "Epoch: 2000 cost = 0.009017\n",
            "Epoch: 2000 cost = 0.009856\n",
            "Epoch: 3000 cost = 0.003041\n",
            "Epoch: 3000 cost = 0.002324\n",
            "Epoch: 4000 cost = 0.001052\n",
            "Epoch: 4000 cost = 0.000802\n",
            "Epoch: 5000 cost = 0.000340\n",
            "Epoch: 5000 cost = 0.000372\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "V2qpik7_cEre",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "acb92763-f7d5-4a7d-bec2-89a76cb3269d"
      },
      "source": [
        "input = [sen.split()[:2] for sen in sentences]\n",
        "# Predict\n",
        "hidden = torch.zeros(1, len(input), n_hidden)\n",
        "predict = model(hidden, input_batch).data.max(1, keepdim=True)[1]\n",
        "print([sen.split()[:2] for sen in sentences], '->', [idx2word[n.item()] for n in predict.squeeze()])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[['i', 'like'], ['i', 'love'], ['i', 'hate']] -> ['dog', 'coffee', 'milk']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LL520CdCRK9h",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}